{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 网络中的网络（NiN）\n",
    ":label:`sec_nin`\n",
    "\n",
    "LeNet、AlexNet 和 VGG 都有一个共同的设计模式：通过一系列的卷积层与池化层来提取空间结构特征；然后通过全连接层对特征的表征进行处理。\n",
    "AlexNet 和 VGG 对 LeNet 的改进主要在于如何扩大和加深这两个模块。\n",
    "或者，可以想象在这个过程的早期使用全连接层。\n",
    "然而，如果使用稠密层了，可能会完全放弃表征的空间结构。\n",
    "*网络中的网络* (*NiN*) 提供了一个非常简单的解决方案：在每个像素的通道上分别使用多层感知机 :cite:`Lin.Chen.Yan.2013`\n",
    "\n",
    "## (**NiN块**)\n",
    "\n",
    "回想一下，卷积层的输入和输出由四维张量组成，张量的每个轴分别对应样本、通道、高度和宽度。\n",
    "另外，全连接层的输入和输出通常是分别对应于样本和特征的二维张量。\n",
    "NiN 的想法是在每个像素位置（针对每个高度和宽度）应用一个全连接层。\n",
    "如果我们将权重连接到每个空间位置，我们可以将其视为 $1\\times 1$ 卷积层（如 :numref:`sec_channels` 中所述），或作为在每个像素位置上独立作用的全连接层。\n",
    "从另一个角度看，即将空间维度中的每个像素视为单个样本，将通道维度视为不同特征（feature）。\n",
    "\n",
    ":numref:`fig_nin` 说明了 VGG 和 NiN 及它们的块之间主要结构差异。\n",
    "NiN 块以一个普通卷积层开始，后面是两个 $1\\times 1$ 的卷积层。这两个$1\\times 1$ 卷积层充当带有 ReLU 激活函数的逐像素全连接层。\n",
    "第一层的卷积窗口形状通常由用户设置。\n",
    "随后的卷积窗口形状固定为 $1 \\times 1$。\n",
    "\n",
    "![对比 VGG 和 NiN 及它们的块之间主要结构差异。](../img/nin.svg)\n",
    ":width:`600px`\n",
    ":label:`fig_nin`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import paddle\n",
    "import paddle.nn as nn\n",
    "\n",
    "def nin_block(in_channels, out_channels, kernel_size, strides, padding):\n",
    "    return nn.Sequential(\n",
    "        nn.Conv2D(in_channels, out_channels, kernel_size, strides, padding),\n",
    "        nn.ReLU(), nn.Conv2D(out_channels, out_channels, kernel_size=1),\n",
    "        nn.ReLU(), nn.Conv2D(out_channels, out_channels, kernel_size=1),\n",
    "        nn.ReLU())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## [**NiN模型**]\n",
    "\n",
    "最初的 NiN 网络是在 AlexNet 后不久提出的，显然从中得到了一些启示。\n",
    "NiN使用窗口形状为 $11\\times 11$、$5\\times 5$ 和 $3\\times 3$的卷积层，输出通道数量与 AlexNet 中的相同。\n",
    "每个 NiN 块后有一个最大池化层，池化窗口形状为 $3\\times 3$，步幅为 2。\n",
    "\n",
    "NiN 和 AlexNet 之间的一个显著区别是 NiN 完全取消了全连接层。\n",
    "相反，NiN 使用一个 NiN块，其输出通道数等于标签类别的数量。最后放一个 *全局平均池化层*（global average pooling layer），生成一个多元逻辑向量（logits）。NiN 设计的一个优点是，它显著减少了模型所需参数的数量。然而，在实践中，这种设计有时会增加训练模型的时间。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "NiN = nn.Sequential(\n",
    "    nin_block(1, 96, kernel_size=11, strides=4, padding=0),\n",
    "    nn.MaxPool2D(3, stride=2),\n",
    "    nin_block(96, 256, kernel_size=5, strides=1, padding=2),\n",
    "    nn.MaxPool2D(3, stride=2),\n",
    "    nin_block(256, 384, kernel_size=3, strides=1, padding=1),\n",
    "    nn.MaxPool2D(3, stride=2), nn.Dropout(0.5),\n",
    "    # 标签类别数是10\n",
    "    nin_block(384, 10, kernel_size=3, strides=1, padding=1),\n",
    "    nn.AdaptiveAvgPool2D((1, 1)),\n",
    "    # 将四维的输出转成二维的输出，其形状为(批量大小, 10)\n",
    "    nn.Flatten())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "我们创建一个数据样本来[**查看每个块的输出形状**]。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------------------------------------------------------------------\n",
      "   Layer (type)         Input Shape          Output Shape         Param #    \n",
      "===============================================================================\n",
      "     Conv2D-1        [[1, 1, 224, 224]]    [1, 96, 54, 54]        11,712     \n",
      "      ReLU-1         [[1, 96, 54, 54]]     [1, 96, 54, 54]           0       \n",
      "     Conv2D-2        [[1, 96, 54, 54]]     [1, 96, 54, 54]         9,312     \n",
      "      ReLU-2         [[1, 96, 54, 54]]     [1, 96, 54, 54]           0       \n",
      "     Conv2D-3        [[1, 96, 54, 54]]     [1, 96, 54, 54]         9,312     \n",
      "      ReLU-3         [[1, 96, 54, 54]]     [1, 96, 54, 54]           0       \n",
      "    MaxPool2D-1      [[1, 96, 54, 54]]     [1, 96, 26, 26]           0       \n",
      "     Conv2D-4        [[1, 96, 26, 26]]     [1, 256, 26, 26]       614,656    \n",
      "      ReLU-4         [[1, 256, 26, 26]]    [1, 256, 26, 26]          0       \n",
      "     Conv2D-5        [[1, 256, 26, 26]]    [1, 256, 26, 26]       65,792     \n",
      "      ReLU-5         [[1, 256, 26, 26]]    [1, 256, 26, 26]          0       \n",
      "     Conv2D-6        [[1, 256, 26, 26]]    [1, 256, 26, 26]       65,792     \n",
      "      ReLU-6         [[1, 256, 26, 26]]    [1, 256, 26, 26]          0       \n",
      "    MaxPool2D-2      [[1, 256, 26, 26]]    [1, 256, 12, 12]          0       \n",
      "     Conv2D-7        [[1, 256, 12, 12]]    [1, 384, 12, 12]       885,120    \n",
      "      ReLU-7         [[1, 384, 12, 12]]    [1, 384, 12, 12]          0       \n",
      "     Conv2D-8        [[1, 384, 12, 12]]    [1, 384, 12, 12]       147,840    \n",
      "      ReLU-8         [[1, 384, 12, 12]]    [1, 384, 12, 12]          0       \n",
      "     Conv2D-9        [[1, 384, 12, 12]]    [1, 384, 12, 12]       147,840    \n",
      "      ReLU-9         [[1, 384, 12, 12]]    [1, 384, 12, 12]          0       \n",
      "    MaxPool2D-3      [[1, 384, 12, 12]]     [1, 384, 5, 5]           0       \n",
      "     Dropout-1        [[1, 384, 5, 5]]      [1, 384, 5, 5]           0       \n",
      "     Conv2D-10        [[1, 384, 5, 5]]      [1, 10, 5, 5]         34,570     \n",
      "      ReLU-10         [[1, 10, 5, 5]]       [1, 10, 5, 5]            0       \n",
      "     Conv2D-11        [[1, 10, 5, 5]]       [1, 10, 5, 5]           110      \n",
      "      ReLU-11         [[1, 10, 5, 5]]       [1, 10, 5, 5]            0       \n",
      "     Conv2D-12        [[1, 10, 5, 5]]       [1, 10, 5, 5]           110      \n",
      "      ReLU-12         [[1, 10, 5, 5]]       [1, 10, 5, 5]            0       \n",
      "AdaptiveAvgPool2D-1   [[1, 10, 5, 5]]       [1, 10, 1, 1]            0       \n",
      "     Flatten-1        [[1, 10, 1, 1]]          [1, 10]               0       \n",
      "===============================================================================\n",
      "Total params: 1,992,166\n",
      "Trainable params: 1,992,166\n",
      "Non-trainable params: 0\n",
      "-------------------------------------------------------------------------------\n",
      "Input size (MB): 0.19\n",
      "Forward/backward pass size (MB): 24.20\n",
      "Params size (MB): 7.60\n",
      "Estimated Total Size (MB): 31.99\n",
      "-------------------------------------------------------------------------------\n",
      "\n",
      "{'total_params': 1992166, 'trainable_params': 1992166}\n"
     ]
    }
   ],
   "source": [
    "print(paddle.summary(NiN, (1, 1, 224, 224)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## [**训练模型**]\n",
    "\n",
    "和以前一样，我们使用 Fashion-MNIST 来训练模型。训练 NiN 与训练 AlexNet、VGG时相似。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The loss value printed in the log is the current step, and the metric is the average value of previous steps.\n",
      "Epoch 1/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n",
      "  return (isinstance(seq, collections.Sequence) and\n"
     ]
    }
   ],
   "source": [
    "import paddle.vision.transforms as T\n",
    "from paddle.vision.datasets import FashionMNIST\n",
    "\n",
    "lr, num_epochs, batch_size = 0.001, 10, 128\n",
    "\n",
    "# 数据集处理\n",
    "transform = T.Compose([\n",
    "    T.Resize(224),\n",
    "    T.Transpose(),\n",
    "    T.Normalize([127.5], [127.5]),\n",
    "])\n",
    "# 数据集定义\n",
    "train_dataset = FashionMNIST(mode='train', transform=transform)\n",
    "val_dataset = FashionMNIST(mode='test', transform=transform)\n",
    "\n",
    "# 模型设置\n",
    "model = paddle.Model(NiN)\n",
    "model.prepare(\n",
    "    paddle.optimizer.Adam(learning_rate=lr, parameters=model.parameters()),\n",
    "    paddle.nn.CrossEntropyLoss(),\n",
    "    paddle.metric.Accuracy(topk=(1, 5)))\n",
    "# 模型训练\n",
    "callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')\n",
    "model.fit(train_dataset, val_dataset, epochs=num_epochs, batch_size=batch_size, log_freq=200, callbacks=callback)\n",
    "# model.fit(train_dataset, val_dataset, epochs=1, batch_size=batch_size, log_freq=1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 小结\n",
    "\n",
    "* NiN使用由一个卷积层和多个 $1\\times 1$ 卷积层组成的块。该块可以在卷积神经网络中使用，以允许更多的每像素非线性。\n",
    "* NiN去除了容易造成过拟合的全连接层，将它们替换为全局平均池化层（即在所有位置上进行求和）。该池化层通道数量为所需的输出数量（例如，Fashion-MNIST的输出为10）。\n",
    "* 移除全连接层可减少过拟合，同时显著减少NiN的参数。\n",
    "* NiN的设计影响了许多后续卷积神经网络的设计。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 调整NiN的超参数，以提高分类准确性。\n",
    "1. 为什么NiN块中有两个 $1\\times 1$ 卷积层？删除其中一个，然后观察和分析实验现象。\n",
    "1. 计算NiN的资源使用情况。\n",
    "    1. 参数的数量是多少？\n",
    "    1. 计算量是多少？\n",
    "    1. 训练期间需要多少显存？\n",
    "    1. 预测期间需要多少显存？\n",
    "1. 一次性直接将 $384 \\times 5 \\times 5$ 的表示缩减为 $10 \\times 5 \\times 5$ 的表示，会存在哪些问题？\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/1869)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PaddlePaddle 2.1.0 (Python 3.5)",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
